{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## self attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 14_828\n",
    "\n",
    "EPOCHS = 5\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 0.01\n",
    "BEST_VALID_LOSS = float('inf')\n",
    "\n",
    "EMBEDDING_DIM = 100\n",
    "OUTPUT_DIM = 1\n",
    "\n",
    "train_file = \"data/senti.train.tsv\"\n",
    "eval_file = \"data/senti.dev.tsv\"\n",
    "test_file = \"data/senti.test.tsv\"\n",
    "\n",
    "USE_CUDA = torch.cuda.is_available()\n",
    "DEVICE = torch.device('cuda:1' if USE_CUDA else 'cpu')\n",
    "NUM_CUDA = torch.cuda.device_count()\n",
    "\n",
    "random.seed(2019)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_text_file(filename):\n",
    "    \"\"\"将样本的特征与标签分开，并将样本特征分词\"\"\"\n",
    "    sentences = []\n",
    "    label = []\n",
    "    with open(filename, \"r\") as f:\n",
    "        sent_list  = [line.strip().split('\\t') for line in f]\n",
    "    for sample in sent_list:\n",
    "        sentences.append(sample[0].lower().split(\" \"))\n",
    "        label.append(int(sample[-1]))\n",
    "    return sentences, label\n",
    "\n",
    "\n",
    "def build_word_dic(sentences_list, vocab_size=20_000):\n",
    "    \"\"\"构建words_set, word2idx, idx2word\"\"\"\n",
    "    words_list = [w for line in sentences_list for w in line]\n",
    "    counter = Counter(words_list)\n",
    "    words_topn = counter.most_common(vocab_size)\n",
    "    words_set = [item[0] for item in words_topn]\n",
    "    words_set = ['<pad>', \"<unk>\"] + words_set\n",
    "    word2idx = {w:i for i, w in enumerate(words_set)}\n",
    "    idx2word = {i:w for i, w in enumerate(words_set)}\n",
    "    return words_topn, word2idx, idx2word\n",
    "\n",
    "\n",
    "def build_x_y(word2idx, sentences_list, label_list, seq_len=60):\n",
    "    \"\"\"构建输入模型的数据，对每个单词编码，每个句子通过添加pading保持一样长\"\"\"\n",
    "    x = []\n",
    "    y = []\n",
    "    for sent, label in zip(sentences_list, label_list):\n",
    "        word_x = [0]* seq_len\n",
    "        for i, w in enumerate(sent):\n",
    "            if w in word2idx:\n",
    "                word_x[i] = word2idx[w]\n",
    "            else:\n",
    "                word_x[i] = word2idx['<unk>']\n",
    "        x.append(word_x)\n",
    "        y.append(label)\n",
    "    return x, y\n",
    "\n",
    "\n",
    "def build_batch_data(data, label, batch_size=32):\n",
    "    \"\"\"构建tensor格式的批次数据，返回batch列表，每个batch为二元组包含feature和label\"\"\"\n",
    "    batch_data = []\n",
    "    # 打乱顺序\n",
    "    data_labels = [[x, y] for x, y in zip(data, label)]\n",
    "    random.shuffle(data_labels)\n",
    "    xlist = [item[0] for item in data_labels]\n",
    "    ylist = [item[1] for item in data_labels]\n",
    "    \n",
    "    x_tensor = torch.tensor(xlist, dtype=torch.long)\n",
    "    y_tensor = torch.tensor(ylist, dtype=torch.float)\n",
    "    n, dim = x_tensor.size()\n",
    "    for start in range(0, n, batch_size):\n",
    "        end = start + batch_size\n",
    "        if end > n:\n",
    "            xbatch = x_tensor[start: ]\n",
    "            ybatch = y_tensor[start: ]\n",
    "            print(\"最后一个batch size:\", ybatch.size())\n",
    "        else:\n",
    "            xbatch = x_tensor[start: end]\n",
    "            ybatch = y_tensor[start: end]\n",
    "        batch_data.append((xbatch, ybatch))\n",
    "    return batch_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, output_size, pad_idx):\n",
    "        super(SelfAttModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n",
    "        initrange = 0.1\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        # 计算 Attention 向量\n",
    "        self.qkv = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.fc = nn.Linear(embed_dim, output_size, bias=False)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        # [batch, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        embed = self.embedding(text)\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, embed_dim]?\n",
    "        x = self.qkv(embed)     # 用emeding层产生？\n",
    "        # 算句子Attention平均值\n",
    "        h_attn = self.attention(x)   # [batch, seq_len, emb_dim]\n",
    "        # 平均值\n",
    "        # [batch, seq_len, emb_dim] -> [batch, emb_dim]  # 每个句子求平均值得到一个词向量\n",
    "        h_attn = torch.sum(h_attn, dim=1).squeeze()\n",
    "        # [batch, emb_dim] --> [batch, output_size]\n",
    "        out = self.fc(h_attn)\n",
    "        return out\n",
    "    \n",
    "    def attention(self, x):\n",
    "        \"\"\"计算attention权重\"\"\"\n",
    "        d_k = x.size(-1)    # embed_dim\n",
    "        # x.transpose(-2, -1) 后两维度的转置 [batch, seq_len, emb_dim] --> [batch, emb_dim, seq_len]\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, seq_len]\n",
    "        scores = torch.matmul(x, x.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "        # [batch, seq_len, seq_len] ->[batch, seq_len, seq_len]\n",
    "        attn = F.softmax(scores, dim=-1)\n",
    "        # 计算context值 \n",
    "        # [batch, seq_len, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        attn_x = torch.matmul(attn, x)\n",
    "        return attn_x\n",
    "    \n",
    "    def get_embed_weight(self):\n",
    "        \"\"\"获取embedding层参数\"\"\"\n",
    "        return self.embedding.weight.data\n",
    "\n",
    "\n",
    "def binary_accuracy(preds, y):\n",
    "    \"\"\"计算准确率\"\"\"\n",
    "    rounded_preds = torch.round(torch.sigmoid(preds))\n",
    "    correct = (rounded_preds == y).float()  \n",
    "    acc = correct.sum()/len(correct)\n",
    "    return acc\n",
    "\n",
    "\n",
    "def train(model, device, iterator, optimizer, criterion):\n",
    "    \"\"\"Run one optimisation pass over `iterator`.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping a feature batch to [batch, 1] logits.\n",
    "        device: device the batches are moved to.\n",
    "        iterator: sized collection of (features, labels) batches.\n",
    "        optimizer: optimiser stepping the model parameters.\n",
    "        criterion: loss taking (logits, labels).\n",
    "\n",
    "    Returns:\n",
    "        (mean loss, mean accuracy), both averaged over batches (not samples).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    loss_total = 0\n",
    "    acc_total = 0\n",
    "\n",
    "    for features, targets in iterator:\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(features).squeeze(1)   # [batch, 1] -> [batch]\n",
    "        batch_loss = criterion(logits, targets)\n",
    "        batch_acc = binary_accuracy(logits, targets)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_total += batch_loss.item()\n",
    "        acc_total += batch_acc.item()\n",
    "\n",
    "    n_batches = len(iterator)\n",
    "    return loss_total / n_batches, acc_total / n_batches\n",
    "\n",
    "\n",
    "def evaluate(model, device, iterator, criterion):\n",
    "    \"\"\"Score `model` on `iterator` in eval mode, without tracking gradients.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping a feature batch to [batch, 1] logits.\n",
    "        device: device the batches are moved to.\n",
    "        iterator: sized collection of (features, labels) batches.\n",
    "        criterion: loss taking (logits, labels).\n",
    "\n",
    "    Returns:\n",
    "        (mean loss, mean accuracy), both averaged over batches (not samples).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_total, acc_total = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for features, targets in iterator:\n",
    "            features, targets = features.to(device), targets.to(device)\n",
    "            logits = model(features).squeeze(1)   # [batch, 1] -> [batch]\n",
    "            loss_total += criterion(logits, targets).item()\n",
    "            acc_total += binary_accuracy(logits, targets).item()\n",
    "\n",
    "    n_batches = len(iterator)\n",
    "    return loss_total / n_batches, acc_total / n_batches\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    \"\"\"统计模型的参数量\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "\n",
    "def epoch_time(start_time, end_time):\n",
    "    \"\"\"计算时间差，返回分钟, 秒钟\"\"\"\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs\n",
    "\n",
    "def max_seq_len(data_list):\n",
    "    \"\"\"获取句子最大长度\"\"\"\n",
    "    li = [len(s) for data in data_list for s in data]\n",
    "    return max(li)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_len: 56\n",
      "词典长度: 14828 14830 14830\n",
      "训练集样本数量: 67349 67349\n",
      "最后一个batch size: torch.Size([21])\n",
      "最后一个batch size: torch.Size([8])\n",
      "最后一个batch size: torch.Size([29])\n"
     ]
    }
   ],
   "source": [
    "# Load the three splits as (token lists, integer labels).\n",
    "train_sentences, train_label = load_text_file(train_file)\n",
    "eval_sentences, eval_label = load_text_file(eval_file)\n",
    "test_sentences, test_label = load_text_file(test_file)\n",
    "\n",
    "# Every split is padded to the longest sentence found in any of them.\n",
    "max_len = max_seq_len([train_sentences, eval_sentences, test_sentences])\n",
    "print(\"max_len:\", max_len)\n",
    "# The vocabulary is built from the training split only. Note that `words_set`\n",
    "# receives the (word, count) list, which excludes '<pad>' and '<unk>'.\n",
    "words_set, word2idx, idx2word = build_word_dic(train_sentences, vocab_size=VOCAB_SIZE)\n",
    "# Encode each split as fixed-length index lists.\n",
    "train_x, train_y = build_x_y(word2idx, train_sentences, train_label, seq_len=max_len)\n",
    "eval_x, eval_y = build_x_y(word2idx, eval_sentences, eval_label, seq_len=max_len)\n",
    "test_x, test_y = build_x_y(word2idx, test_sentences, test_label, seq_len=max_len)\n",
    "\n",
    "print(\"词典长度:\", len(words_set), len(word2idx), len(idx2word))\n",
    "print(\"训练集样本数量:\", len(train_x), len(train_y))\n",
    "\n",
    "# Shuffle once and pre-slice into tensor batches; the same batch lists are\n",
    "# reused for every epoch and every model below.\n",
    "train_data = build_batch_data(train_x, train_y, batch_size=BATCH_SIZE)\n",
    "eval_data = build_batch_data(eval_x, eval_y, batch_size=BATCH_SIZE)\n",
    "test_data = build_batch_data(test_x, test_y, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型有1,493,100个可调节参数, 大约5.6957244873046875 M.\n"
     ]
    }
   ],
   "source": [
    "# `words_set` holds only the most frequent words; +2 accounts for '<pad>' and '<unk>'.\n",
    "INPUT_DIM = len(words_set) + 2\n",
    "PAD_IDX = word2idx['<pad>']\n",
    "\n",
    "model = SelfAttModel(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)\n",
    "# Size estimate assumes 4 bytes per parameter; the printed value is in MiB.\n",
    "print(f'模型有{count_parameters(model):,}个可调节参数, 大约{count_parameters(model)*4/1024/1024} M.')\n",
    "\n",
    "model = model.to(DEVICE)\n",
    "    \n",
    "# The optimiser is created after the move so it references the on-device parameters.\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "# Sigmoid + binary cross-entropy on the raw logits.\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/lib/python3.6/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type SelfAttModel. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***Save Best Model self-attention-wordavg.pth***\n",
      "Epoch: 01 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.391 | Train Acc: 83.59%\n",
      "\t Val. Loss: 0.630 |  Val. Acc: 80.36%\n",
      "Epoch: 02 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.380 | Train Acc: 89.68%\n",
      "\t Val. Loss: 1.530 |  Val. Acc: 76.34%\n",
      "Epoch: 03 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.466 | Train Acc: 91.57%\n",
      "\t Val. Loss: 2.797 |  Val. Acc: 78.24%\n",
      "Epoch: 04 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.446 | Train Acc: 93.12%\n",
      "\t Val. Loss: 5.711 |  Val. Acc: 77.01%\n",
      "Epoch: 05 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.379 | Train Acc: 94.06%\n",
      "\t Val. Loss: 5.046 |  Val. Acc: 79.13%\n"
     ]
    }
   ],
   "source": [
    "model_name = 'self-attention-wordavg.pth'\n",
    "# Start from a clean slate so that re-running this cell (or running it after\n",
    "# another training cell) still saves the best checkpoint of *this* run.\n",
    "BEST_VALID_LOSS = float('inf')\n",
    "\n",
    "for epoch in range(1, EPOCHS+1):\n",
    "    # The reported epoch time covers both the training and the validation pass.\n",
    "    start_time = time.time()\n",
    "    train_loss, train_acc = train(model, DEVICE, train_data, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, DEVICE, eval_data, criterion)\n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    # Checkpoint only when the validation loss improves.\n",
    "    if valid_loss < BEST_VALID_LOSS:\n",
    "        BEST_VALID_LOSS = valid_loss\n",
    "        # Saves the whole pickled module, not just its state_dict.\n",
    "        torch.save(model, model_name)\n",
    "        print(f'***Save Best Model {model_name}***')\n",
    "    \n",
    "    print(f'Epoch: {epoch :02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.5601791255829627 | Test Acc: 0.8189088022499754\n"
     ]
    }
   ],
   "source": [
    "# Reload the best checkpoint; map_location places it on DEVICE even if it was\n",
    "# saved from a different device.\n",
    "model = torch.load(model_name, map_location=DEVICE)\n",
    "test_loss, test_acc = evaluate(model, DEVICE, test_data, criterion)\n",
    "print('Test Loss: {0} | Test Acc: {1}'.format(test_loss, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add a residual connection (残差)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttentionResidualModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, output_size, pad_idx):\n",
    "        super(AttentionResidualModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n",
    "        initrange = 0.1\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        self.qkv = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.fc = nn.Linear(embed_dim, output_size, bias=False)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        # [batch, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        embed = self.embedding(text)\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, embed_dim]?\n",
    "        x = self.qkv(embed)     # 用emeding层产生？\n",
    "        # 算句子Attention平均值\n",
    "        h_attn = self.attention(x)   # [batch, seq_len, emb_dim]\n",
    "        h_attn += embed\n",
    "        # 平均值\n",
    "        # [batch, seq_len, emb_dim] -> [batch, emb_dim]  # 每个句子求平均值得到一个词向量\n",
    "        h_attn = torch.sum(h_attn, dim=1).squeeze()\n",
    "        # [batch, emb_dim] --> [batch, output_size]\n",
    "        out = self.fc(self.dropout(h_attn))\n",
    "        return out\n",
    "    \n",
    "    def attention(self, x):\n",
    "        \"\"\"计算attention权重\"\"\"\n",
    "        d_k = x.size(-1)    # embed_dim\n",
    "        # x.transpose(-2, -1) 后两维度的转置 [batch, seq_len, emb_dim] --> [batch, emb_dim, seq_len]\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, seq_len]\n",
    "        scores = torch.matmul(x, x.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "        # [batch, seq_len, seq_len] ->[batch, seq_len, seq_len]\n",
    "        attn = F.softmax(scores, dim=-1)\n",
    "        # 计算context值 \n",
    "        # [batch, seq_len, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        attn_x = torch.matmul(attn, x)\n",
    "        return attn_x\n",
    "    \n",
    "    def get_embed_weight(self):\n",
    "        \"\"\"获取embedding层参数\"\"\"\n",
    "        return self.embedding.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型有1,493,100个可调节参数, 大约5.6957244873046875 M.\n"
     ]
    }
   ],
   "source": [
    "res_model = AttentionResidualModel(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)\n",
    "# Size estimate assumes 4 bytes per parameter; the printed value is in MiB.\n",
    "print(f'模型有{count_parameters(res_model):,}个可调节参数, 大约{count_parameters(res_model)*4/1024/1024} M.')\n",
    "\n",
    "res_model = res_model.to(DEVICE)\n",
    "\n",
    "# `optimizer` and `criterion` are rebound here for the new model.\n",
    "optimizer = optim.Adam(res_model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/lib/python3.6/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type AttentionResidualModel. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***Save Best Model attention-residual-wordavg.pth***\n",
      "Epoch: 01 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.411 | Train Acc: 82.54%\n",
      "\t Val. Loss: 0.680 |  Val. Acc: 80.25%\n",
      "Epoch: 02 | Epoch Time: 0m 9s\n",
      "\tTrain Loss: 0.423 | Train Acc: 89.12%\n",
      "\t Val. Loss: 1.333 |  Val. Acc: 78.68%\n",
      "Epoch: 03 | Epoch Time: 0m 9s\n",
      "\tTrain Loss: 0.450 | Train Acc: 91.44%\n",
      "\t Val. Loss: 4.090 |  Val. Acc: 79.13%\n",
      "Epoch: 04 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 1.355 | Train Acc: 92.43%\n",
      "\t Val. Loss: 2.013 |  Val. Acc: 78.91%\n",
      "Epoch: 05 | Epoch Time: 0m 9s\n",
      "\tTrain Loss: 0.232 | Train Acc: 94.51%\n",
      "\t Val. Loss: 3.674 |  Val. Acc: 77.68%\n"
     ]
    }
   ],
   "source": [
    "res_model_name = 'attention-residual-wordavg.pth'\n",
    "# Reset the running best so this run's checkpointing is independent of earlier runs.\n",
    "BEST_VALID_LOSS = float('inf')\n",
    "\n",
    "for epoch in range(1, EPOCHS+1):\n",
    "    # The reported epoch time covers both the training and the validation pass.\n",
    "    start_time = time.time()\n",
    "    train_loss, train_acc = train(res_model, DEVICE, train_data, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(res_model, DEVICE, eval_data, criterion)\n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    # Checkpoint only when the validation loss improves.\n",
    "    if valid_loss < BEST_VALID_LOSS:\n",
    "        BEST_VALID_LOSS = valid_loss\n",
    "        # Saves the whole pickled module, not just its state_dict.\n",
    "        torch.save(res_model, res_model_name)\n",
    "        print(f'***Save Best Model {res_model_name}***')\n",
    "    \n",
    "    print(f'Epoch: {epoch :02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.6042031730737603 | Test Acc: 0.8194570478640104\n"
     ]
    }
   ],
   "source": [
    "# Reload the best checkpoint; map_location places it on DEVICE even if it was\n",
    "# saved from a different device.\n",
    "res_model = torch.load(res_model_name, map_location=DEVICE)\n",
    "test_loss, test_acc = evaluate(res_model, DEVICE, test_data, criterion)\n",
    "print('Test Loss: {0} | Test Acc: {1}'.format(test_loss, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自己设置attention函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyAttentionModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, output_size, pad_idx):\n",
    "        super(MyAttentionModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n",
    "        initrange = 0.1\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        # 权重计算 q, v, k\n",
    "        self.q = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.k = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.v = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.fc = nn.Linear(embed_dim, output_size, bias=False)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        \n",
    "    def forward(self, text):\n",
    "        # [batch, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        embed = self.embedding(text)\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, embed_dim]?\n",
    "        q_vec = self.q(embed) \n",
    "        k_vec = self.k(embed)\n",
    "        v_vec = self.v(embed)\n",
    "        # 算句子Attention平均值\n",
    "        h_attn = self.attention(q_vec, k_vec, v_vec)   # [batch, seq_len, emb_dim]\n",
    "        h_attn += embed\n",
    "        # 平均值\n",
    "        # [batch, seq_len, emb_dim] -> [batch, emb_dim]  # 每个句子求平均值得到一个词向量\n",
    "        h_attn = torch.sum(h_attn, dim=1).squeeze()\n",
    "        # [batch, emb_dim] --> [batch, output_size]\n",
    "        out = self.fc(self.dropout(h_attn))\n",
    "        return out\n",
    "    \n",
    "    def attention(self, q, k, v):\n",
    "        \"\"\"计算attention权重\"\"\"\n",
    "        d_k = k.size(-1)    # embed_dim\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, seq_len]\n",
    "        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "        # [batch, seq_len, seq_len] ->[batch, seq_len, seq_len]\n",
    "        attn = F.softmax(scores, dim=-1)\n",
    "        # 计算context值 \n",
    "        # [batch, seq_len, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        attn_x = torch.matmul(attn, v)\n",
    "        return attn_x\n",
    "    \n",
    "    def get_embed_weight(self):\n",
    "        \"\"\"获取embedding层参数\"\"\"\n",
    "        return self.embedding.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型有1,513,100个可调节参数, 大约5.7720184326171875 M.\n"
     ]
    }
   ],
   "source": [
    "att_model = MyAttentionModel(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)\n",
    "# Size estimate assumes 4 bytes per parameter; the printed value is in MiB.\n",
    "print(f'模型有{count_parameters(att_model):,}个可调节参数, 大约{count_parameters(att_model)*4/1024/1024} M.')\n",
    "\n",
    "att_model = att_model.to(DEVICE)\n",
    "\n",
    "# `optimizer` and `criterion` are rebound here for the new model.\n",
    "optimizer = optim.Adam(att_model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/lib/python3.6/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type MyAttentionModel. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***Save Best Model my-attention-wordavg.pth***\n",
      "Epoch: 01 | Epoch Time: 0m 10s\n",
      "\tTrain Loss: 0.461 | Train Acc: 82.26%\n",
      "\t Val. Loss: 0.669 |  Val. Acc: 78.12%\n",
      "Epoch: 02 | Epoch Time: 0m 10s\n",
      "\tTrain Loss: 0.453 | Train Acc: 88.23%\n",
      "\t Val. Loss: 1.575 |  Val. Acc: 76.79%\n",
      "Epoch: 03 | Epoch Time: 0m 11s\n",
      "\tTrain Loss: 0.599 | Train Acc: 88.79%\n",
      "\t Val. Loss: 2.247 |  Val. Acc: 76.45%\n",
      "Epoch: 04 | Epoch Time: 0m 11s\n",
      "\tTrain Loss: 0.655 | Train Acc: 90.28%\n",
      "\t Val. Loss: 3.758 |  Val. Acc: 79.24%\n",
      "Epoch: 05 | Epoch Time: 0m 10s\n",
      "\tTrain Loss: 3.910 | Train Acc: 87.12%\n",
      "\t Val. Loss: 9.320 |  Val. Acc: 74.11%\n"
     ]
    }
   ],
   "source": [
    "att_model_name = 'my-attention-wordavg.pth'\n",
    "# Reset the running best so this run's checkpointing is independent of earlier runs.\n",
    "BEST_VALID_LOSS = float('inf')\n",
    "\n",
    "for epoch in range(1, EPOCHS+1):\n",
    "    # The reported epoch time covers both the training and the validation pass.\n",
    "    start_time = time.time()\n",
    "    train_loss, train_acc = train(att_model, DEVICE, train_data, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(att_model, DEVICE, eval_data, criterion)\n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    # Checkpoint only when the validation loss improves.\n",
    "    if valid_loss < BEST_VALID_LOSS:\n",
    "        BEST_VALID_LOSS = valid_loss\n",
    "        # Saves the whole pickled module, not just its state_dict.\n",
    "        torch.save(att_model, att_model_name)\n",
    "        print(f'***Save Best Model {att_model_name}***')\n",
    "    \n",
    "    print(f'Epoch: {epoch :02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.5882109670262587 | Test Acc: 0.8019131882148877\n"
     ]
    }
   ],
   "source": [
    "# Reload the best checkpoint; map_location places it on DEVICE even if it was\n",
    "# saved from a different device.\n",
    "att_model = torch.load(att_model_name, map_location=DEVICE)\n",
    "test_loss, test_acc = evaluate(att_model, DEVICE, test_data, criterion)\n",
    "print('Test Loss: {0} | Test Acc: {1}'.format(test_loss, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MyAttentionModel(\n",
       "  (embedding): Embedding(16000, 100, padding_idx=0)\n",
       "  (q): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (k): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (v): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (fc): Linear(in_features=100, out_features=1, bias=False)\n",
       "  (dropout): Dropout(p=0.2)\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Put the model in evaluation mode; eval() returns the module, so its repr\n",
    "# becomes the cell output.\n",
    "att_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, device, x):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        x = x.to(device)\n",
    "        y = model(x)\n",
    "        print(y)\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 56])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First sample of the first test batch, with a leading batch dimension added.\n",
    "x = test_data[0][0][0].unsqueeze(0)\n",
    "x.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label of that same sample (second element of the batch tuple).\n",
    "gt = test_data[0][-1][0]\n",
    "gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-10.3408], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "# Raw logit for the sample; predict() also prints it.\n",
    "y = predict(att_model, DEVICE, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the logit into a value in (0, 1).\n",
    "p_y = torch.sigmoid(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1], device='cuda:1', dtype=torch.uint8)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# True where the value is below 0.5, i.e. where binary_accuracy would round to label 0.\n",
    "p_y < 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add positional encodings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PosAttentionModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, output_size, pad_idx):\n",
    "        super(PosAttentionModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n",
    "        initrange = 0.1\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        # 权重计算 q, v, k\n",
    "        self.q = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.k = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.v = nn.Linear(embed_dim, embed_dim, bias=False)\n",
    "        self.fc = nn.Linear(embed_dim, output_size, bias=False)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        \n",
    "        \n",
    "    def forward(self, text):\n",
    "        # [batch, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        embed = self.embedding(text)\n",
    "        max_len, embed_dim = embed.size()[1], embed.size(2)\n",
    "        pe = self.get_pe(max_len, embed_dim)\n",
    "        # embed += Variable(self.pe[:, :x.size(1)],requires_grad=False)\n",
    "        embed += pe\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, embed_dim]?\n",
    "        q_vec = self.q(embed) \n",
    "        k_vec = self.k(embed)\n",
    "        v_vec = self.v(embed)\n",
    "        # 算句子Attention平均值\n",
    "        h_attn = self.attention(q_vec, k_vec, v_vec)   # [batch, seq_len, emb_dim]\n",
    "        h_attn += embed\n",
    "        # 平均值\n",
    "        # [batch, seq_len, emb_dim] -> [batch, emb_dim]  # 每个句子求平均值得到一个词向量\n",
    "        h_attn = torch.sum(h_attn, dim=1).squeeze()\n",
    "        # [batch, emb_dim] --> [batch, output_size]\n",
    "        out = self.fc(self.dropout(h_attn))\n",
    "        return out\n",
    "    \n",
    "    def attention(self, q, k, v):\n",
    "        \"\"\"计算attention权重\"\"\"\n",
    "        d_k = k.size(-1)    # embed_dim\n",
    "        # [batch, seq_len, emb_dim] -> [batch, seq_len, seq_len]\n",
    "        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "        # [batch, seq_len, seq_len] ->[batch, seq_len, seq_len]\n",
    "        attn = F.softmax(scores, dim=-1)\n",
    "        # 计算context值 \n",
    "        # [batch, seq_len, seq_len] -> [batch, seq_len, emb_dim]\n",
    "        attn_x = torch.matmul(attn, v)\n",
    "        return attn_x \n",
    "    \n",
    "    @property\n",
    "    def get_pe(self, max_len, embed_dim):\n",
    "        pe = torch.zeros(max_len, embed_dim)\n",
    "        position = torch.arange(0, max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, embed_dim, 2) *\n",
    "                             -(math.log(10000.0) / embed_dim))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        pe = pe.unsqueeze(0)\n",
    "        return pe\n",
    "    \n",
    "    def get_embed_weight(self):\n",
    "        \"\"\"获取embedding层参数\"\"\"\n",
    "        return self.embedding.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型有1,513,100个可调节参数, 大约5.7720184326171875 M.\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this builds MyAttentionModel, not the PosAttentionModel defined\n",
    "# in this section, so the training/evaluation cells below do not exercise the\n",
    "# positional encodings. Presumably PosAttentionModel was intended - confirm.\n",
    "pos_model = MyAttentionModel(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)\n",
    "# Size estimate assumes 4 bytes per parameter; the printed value is in MiB.\n",
    "print(f'模型有{count_parameters(pos_model):,}个可调节参数, 大约{count_parameters(pos_model)*4/1024/1024} M.')\n",
    "\n",
    "pos_model = pos_model.to(DEVICE)\n",
    "\n",
    "# `optimizer` and `criterion` are rebound here for the new model.\n",
    "optimizer = optim.Adam(pos_model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***Save Best Model pos-attention-wordavg.pth***\n",
      "Epoch: 01 | Epoch Time: 0m 11s\n",
      "\tTrain Loss: 0.424 | Train Acc: 82.91%\n",
      "\t Val. Loss: 0.982 |  Val. Acc: 78.24%\n",
      "Epoch: 02 | Epoch Time: 0m 11s\n",
      "\tTrain Loss: 0.520 | Train Acc: 86.70%\n",
      "\t Val. Loss: 1.653 |  Val. Acc: 80.36%\n",
      "Epoch: 03 | Epoch Time: 0m 11s\n",
      "\tTrain Loss: 2.995 | Train Acc: 83.55%\n",
      "\t Val. Loss: 3.424 |  Val. Acc: 72.43%\n",
      "Epoch: 04 | Epoch Time: 0m 10s\n",
      "\tTrain Loss: 0.534 | Train Acc: 89.54%\n",
      "\t Val. Loss: 1.370 |  Val. Acc: 78.57%\n",
      "Epoch: 05 | Epoch Time: 0m 12s\n",
      "\tTrain Loss: 0.427 | Train Acc: 91.87%\n",
      "\t Val. Loss: 2.589 |  Val. Acc: 72.21%\n"
     ]
    }
   ],
   "source": [
    "pos_model_name = 'pos-attention-wordavg.pth'\n",
    "# Reset the running best so this run's checkpointing is independent of earlier runs.\n",
    "BEST_VALID_LOSS = float('inf')\n",
    "# Rebinds the EPOCHS value set in the configuration cell (same value, 5).\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(1, EPOCHS+1):\n",
    "    # The reported epoch time covers both the training and the validation pass.\n",
    "    start_time = time.time()\n",
    "    train_loss, train_acc = train(pos_model, DEVICE, train_data, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(pos_model, DEVICE, eval_data, criterion)\n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    # Checkpoint only when the validation loss improves.\n",
    "    if valid_loss < BEST_VALID_LOSS:\n",
    "        BEST_VALID_LOSS = valid_loss\n",
    "        # Saves the whole pickled module, not just its state_dict.\n",
    "        torch.save(pos_model, pos_model_name)\n",
    "        print(f'***Save Best Model {pos_model_name}***')\n",
    "    \n",
    "    print(f'Epoch: {epoch :02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 1.033877345030768 | Test Acc: 0.7567869027455648\n"
     ]
    }
   ],
   "source": [
    "# Reload the best checkpoint; map_location places it on DEVICE even if it was\n",
    "# saved from a different device.\n",
    "pos_model = torch.load(pos_model_name, map_location=DEVICE)\n",
    "test_loss, test_acc = evaluate(pos_model, DEVICE, test_data, criterion)\n",
    "print('Test Loss: {0} | Test Acc: {1}'.format(test_loss, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
